from typing import List 

class Solution:
    def numTrees(self, n: int) -> int:
        """Return the n-th Catalan number via bottom-up dynamic programming.

        ``counts[size]`` holds the number of structurally distinct binary
        search trees that can be built from ``size`` distinct keys. Picking
        one key as the root leaves ``left`` keys for the left subtree and
        ``size - 1 - left`` keys for the right; the two sides are independent,
        so their counts multiply, and summing over every possible ``left``
        gives the total for ``size``.

        Args:
            n: Number of distinct keys.

        Returns:
            The count for ``n`` keys (1 when ``n == 0``).
        """
        counts = [0] * (n + 1)
        # Base case: zero keys form exactly one tree (the empty one).
        counts[0] = 1
        for size in range(1, n + 1):
            # `left` is the size of the left subtree; the root uses one key.
            counts[size] = sum(
                counts[left] * counts[size - 1 - left] for left in range(size)
            )
        return counts[n]

if __name__ == "__main__":
    # Smoke test: three distinct keys admit five distinct BST shapes.
    key_count = 3
    expected_shapes = 5
    assert Solution().numTrees(key_count) == expected_shapes
